classdef CalibrateKrusellWithRobotsEconomy < handle
    % Implements calibration of KrusellWithRobotsEconomy.
    %
    % Model moments are computed by MOMENTS_SKILL_PRICES and compared with
    % the weighted targets in TARGETS by OBJECTIVE_SKILL_DIST_LSQNONLIN.
    % CALIBRATE_SKILL_DIST_CONS minimizes the resulting sum of squared
    % residuals with Knitro subject to the ordering constraints in NONLCON.
    % Parameter structs are converted to/from the optimizer's vector via
    % tools.translators.s2x / tools.translators.x2s.
    
    properties
        dist % skill and participation distribution object
        varnames % variables over which to optimize
        targets % struct of calibration targets (each field has .val and .weight)
        wagebins % struct of wagebins for the three sectors (fields M, R, C)
        percentiles % vector of wage percentiles to compute (default: deciles)
        w_lb % lower bound(s) of the wage domain; min(w_lb) is used
        w_ub % upper bound(s) of the wage domain; max(w_ub) is used
    end
    
    properties (SetAccess = protected)
        num_nodes_f
        % number of nodes for integration over :math:`f_i(w)`
        num_nodes_g
        num_nodes_bins
        num_nodes_phi
        num_nodes_percentiles
        % number of nodes used when computing wage percentiles
        
        nodes_f
        % pre-computed nodes for integration over :math:`f_i(w)`
        weights_f
        % pre-computed weights for integration over :math:`f_i(w)`
        
        nodes_percentiles
        % pre-computed nodes used when computing wage percentiles
        weights_percentiles
        % pre-computed weights used when computing wage percentiles
    end
    
    methods
        function self = CalibrateKrusellWithRobotsEconomy(dist,...
                targets, ...
                wagebins, ...
                varnames, w_lb, w_ub, varargin)
            % Construct the calibration object.
            %
            % Required inputs:
            %   dist       - skill and participation distribution object
            %   targets    - struct of calibration targets; each field
            %                has .val and .weight
            %   wagebins   - struct of wage bins with fields M, R, C
            %   varnames   - cell array of variables to optimize over
            %   w_lb, w_ub - positive bound(s) of the wage domain (may be
            %                vectors; their min / max is used)
            %
            % Name-value options (defaults in parentheses):
            %   'percentiles'           (0.1:0.1:0.9), values in [0,1]
            %   'num_nodes_f'           (10)
            %   'num_nodes_g'           (15)
            %   'num_nodes_bins'        (10)
            %   'num_nodes_phi'         (50)
            %   'num_nodes_percentiles' (100)
            
            valid_num_nodes = @(x) isnumeric(x) && isscalar(x) && ...
                x == floor(x) && x > 0;
            % bounds may be vectors; every entry must be positive
            valid_bound = @(x) isnumeric(x) && ~isempty(x) && all(x(:) > 0);
            valid_percentiles = @(x) isnumeric(x) && ~isempty(x) && ...
                all(x(:) >= 0 & x(:) <= 1);
            
            p = inputParser;
            p.addRequired('dist',@isobject)
            p.addRequired('targets',@isstruct)
            p.addRequired('wagebins',@isstruct)
            p.addRequired('varnames',@iscell)
            p.addRequired('w_lb',valid_bound)
            p.addRequired('w_ub',valid_bound)
            p.addParameter('percentiles',0.1:0.1:0.9, valid_percentiles)
            p.addParameter('num_nodes_f',10, valid_num_nodes)
            p.addParameter('num_nodes_g',15, valid_num_nodes)
            p.addParameter('num_nodes_bins',10, valid_num_nodes)
            p.addParameter('num_nodes_phi',50, valid_num_nodes)
            p.addParameter('num_nodes_percentiles',100, valid_num_nodes)
            
            p.parse(dist, targets, wagebins, varnames, w_lb, w_ub, varargin{:});
            
            self.dist = p.Results.dist;
            self.targets = p.Results.targets;
            self.wagebins = p.Results.wagebins;
            self.varnames = p.Results.varnames;
            self.w_lb = p.Results.w_lb;
            self.w_ub = p.Results.w_ub;
            % previously parsed but silently dropped
            self.percentiles = p.Results.percentiles;
            self.num_nodes_f = p.Results.num_nodes_f;
            self.num_nodes_g = p.Results.num_nodes_g;
            self.num_nodes_phi = p.Results.num_nodes_phi;
            self.num_nodes_bins = p.Results.num_nodes_bins;
            self.num_nodes_percentiles = p.Results.num_nodes_percentiles;
            
            % nodes and weights on [-1, 1]; mapped to the wage domain later
            [self.nodes_f, self.weights_f] = ...
                tools.Integration.lgwt(self.num_nodes_f,-1,1);
            [self.nodes_percentiles, self.weights_percentiles] = ...
                tools.Integration.lgwt(self.num_nodes_percentiles,-1,1);
        end
        
        
        function [sol,fval,exitflag,output] = calibrate_skill_dist(self, var0, lb, ub)
            % Unconstrained least-squares calibration (no longer used).
            %
            % var0, lb, ub are in the representation accepted by
            % tools.translators.s2x; SOL is converted back with
            % tools.translators.x2s. Solver options are read from
            % 'CalibrateKrusellWithRobotsEconomyKnitro.opt'.
            obj = @(x) self.objective_skill_dist_lsqnonlin(tools.translators.x2s(self,x));
            x0 = tools.translators.s2x(self,var0);
            
            lb = tools.translators.s2x(self,lb);
            ub = tools.translators.s2x(self,ub);
            
            knitrooptions = optimset( 'Display','iter');
            options_file = 'CalibrateKrusellWithRobotsEconomyKnitro.opt';
            [x,fval,exitflag,output] = knitromatlab_lsqnonlin(obj, x0, lb,ub, ...
                [],knitrooptions, ...
                options_file);
            
            sol = tools.translators.x2s(self,x);
        end
        
        function [sol,fval,exitflag,output] = calibrate_skill_dist_cons(self, var0, lb, ub)
            % Constrained calibration.
            %
            % Minimizes the sum of squared weighted residuals
            % (OBJECTIVE_SKILL_DIST_SSR) subject to the inequality
            % constraints in NONLCON and the bounds lb/ub, using
            % knitromatlab with options from
            % 'CalibrateKrusellWithRobotsEconomyKnitro_con.opt'.
            %
            % var0, lb, ub are in the representation accepted by
            % tools.translators.s2x; SOL is converted back with
            % tools.translators.x2s.
            Amat = [];
            b = [];
            Aeq = [];
            beq = [];
            
            % evaluates the moments once per call (previously twice)
            obj = @(x) self.objective_skill_dist_ssr(tools.translators.x2s(self,x));
            nonlcon = @(x) self.nonlcon(tools.translators.x2s(self,x));
            x0 = tools.translators.s2x(self,var0);
            
            lb = tools.translators.s2x(self,lb);
            ub = tools.translators.s2x(self,ub);
            
            knitrooptions = optimset( 'Display','iter');
            options_file = 'CalibrateKrusellWithRobotsEconomyKnitro_con.opt';
            [x,fval,exitflag,output] = knitromatlab(obj, x0, Amat,b,Aeq,beq,lb,ub, ...
                nonlcon,[],knitrooptions, ...
                options_file);
            
            sol = tools.translators.x2s(self,x);
        end
        
        
        function [c, ceq] = nonlcon(self,inp)
            % Nonlinear constraints for CALIBRATE_SKILL_DIST_CONS.
            %
            % Returns c such that c <= 0 means that, in period 1,
            % average_wage_M_1 <= average_wage_R_1 <= average_wage_C_1.
            % There are no equality constraints (ceq = []).
            moms = self.moments_skill_prices(inp);
            ceq = [];
            % make sure order of average wages is not reversed
            c = [moms.average_wage_M_1 - moms.average_wage_R_1,...
                moms.average_wage_R_1 - moms.average_wage_C_1];
        end
        
        function ssr = objective_skill_dist_ssr(self, inp)
            % Sum of squared weighted residuals for the parameter struct INP.
            %
            % Equal to r' * r with r = OBJECTIVE_SKILL_DIST_LSQNONLIN(inp),
            % but evaluates the residuals (and hence the moments) only once.
            res = self.objective_skill_dist_lsqnonlin(inp);
            ssr = res' * res;
        end
        
        function out = objective_skill_dist_lsqnonlin(self, inp)
            % Weighted calibration residuals for the parameter struct INP.
            %
            % OUT stacks targ.<t>.weight .* (targ.<t>.val - model moment)
            % for the {target, model moment} pairs listed below, in that
            % order. If the product of the three period-0 sector
            % participation shares is below 1e-6, every residual is set to
            % Inf (penalty for (near-)zero participation in a sector). If
            % computing the residuals throws, a NaN column of the expected
            % length is returned so the optimizer can continue.
            targ = self.targets;
            
            % {target field, model moment field}, in stacking order
            residual_map = {
                'freqM',                  'freqM'
                'freqR',                  'freqR'
                'freqC',                  'freqC'
                'participation_rate',     'participation_rate_0'
                'part_M',                 'part_M_0'
                'part_R',                 'part_R_0'
                'part_C',                 'part_C_0'
                'eps_w_K',                'eps_w_K'
                'part_M_change',          'part_M_change'
                'part_R_change',          'part_R_change'
                'part_C_change',          'part_C_change'
                'avg_labor_income',       'avg_labor_income'
                'transfer_lab_inc_share', 'transfer_lab_inc_share'
                };
            n_targets = size(residual_map, 1);
            
            try
                moms = self.moments_skill_prices(inp);
                
                pieces = cell(n_targets, 1);
                for k = 1:n_targets
                    t = targ.(residual_map{k, 1});
                    pieces{k} = t.weight .* (t.val - moms.(residual_map{k, 2}));
                end
                residuals = vertcat(pieces{:});
                
                if moms.part_M_0 * moms.part_R_0 * moms.part_C_0 < 1e-6
                    % penalty for zero participation in at least
                    % one of the sectors
                    out = inf(size(residuals));
                else
                    out = residuals;
                end
                
            catch
                % Best-effort: report failure to the optimizer with NaNs of
                % the same length as a successful residual vector. (The
                % previous code referenced an undefined variable here.)
                n_res = 0;
                for k = 1:n_targets
                    t = targ.(residual_map{k, 1});
                    n_res = n_res + numel(t.weight .* t.val);
                end
                out = NaN(n_res, 1);
            end
            
        end
        
        
        function out = moments_skill_prices(self, inp)
            % Model moments in period 0 and period 1 for the parameter struct INP.
            %
            % INP must contain at least Y_M_0, Y_R_0, Y_C_0, Y_M_1, Y_R_1,
            % Y_C_1, mu_part_M, sigma_part_M, lambda_hsv, tau_hsv, xi and
            % transfer. The participation parameters of sectors R and C are
            % overwritten with those of sector M.
            %
            % OUT is a struct with participation rates and shares,
            % employment shares, average wages by sector, period-0 wage-bin
            % frequencies, elasticities w.r.t. the robot stock, wage
            % percentiles (fields deciles_0 / deciles_1, evaluated at
            % self.percentiles), and labor-income and tax moments.
            
            % participation parameters are common across sectors
            inp.mu_part_R = inp.mu_part_M;
            inp.mu_part_C = inp.mu_part_M;
            
            inp.sigma_part_R = inp.sigma_part_M;
            inp.sigma_part_C = inp.sigma_part_M;
            
            % wage domain used for all integrals below
            w_lb = min(self.w_lb);
            w_ub = max(self.w_ub);
            
            % assign vars to ease notation
            Y_M_0 = inp.Y_M_0;
            Y_R_0 = inp.Y_R_0;
            Y_C_0 = inp.Y_C_0;
            
            Y_M_1 = inp.Y_M_1;
            Y_R_1 = inp.Y_R_1;
            Y_C_1 = inp.Y_C_1;
            
            % relative change in robots (hard-coded stocks in periods 0, 1)
            K_B_0 = 0.36;
            K_B_1 = 1.1;
            out.K_B_rel_change = (K_B_1 - K_B_0)/K_B_0;
            
            % prelims for integration: nodes on [-1,1] mapped to the wage
            % domain, with the derivative of the change of variables
            nodes = self.nodes_f(:);
            weights = self.weights_f(:);
            
            w_nodes = tools.ChangeOfVariables.change_variable_minone_one_log(nodes,w_lb,w_ub);
            dw_dnodes = tools.ChangeOfVariables.change_variable_minone_one_log_deriv(nodes,w_lb,w_ub);
            
            % quadrature of a function sampled at w_nodes
            integrate = @(vals) (vals .* dw_dnodes)' * weights;
            % quadrature of a function handle over [w_lb, w_ub]
            quad = @(fun) tools.Integration.integral_legendre(fun, w_lb, w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
            
            % evaluate relevant functions at nodes
            h_M_w_nodes_0 = self.dist.h_M(w_nodes, Y_M_0, Y_R_0, Y_C_0, inp);
            h_R_w_nodes_0 = self.dist.h_R(w_nodes, Y_M_0, Y_R_0, Y_C_0, inp);
            h_C_w_nodes_0 = self.dist.h_C(w_nodes, Y_M_0, Y_R_0, Y_C_0, inp);
            h_w_nodes_0 = h_M_w_nodes_0 + h_R_w_nodes_0 + h_C_w_nodes_0;
            
            h_M_w_nodes_1 = self.dist.h_M(w_nodes, Y_M_1, Y_R_1, Y_C_1, inp);
            h_R_w_nodes_1 = self.dist.h_R(w_nodes, Y_M_1, Y_R_1, Y_C_1, inp);
            h_C_w_nodes_1 = self.dist.h_C(w_nodes, Y_M_1, Y_R_1, Y_C_1, inp);
            h_w_nodes_1 = h_M_w_nodes_1 + h_R_w_nodes_1 + h_C_w_nodes_1;
            
            % participation rates
            out.participation_rate_0 = integrate(h_w_nodes_0);
            out.participation_rate_1 = integrate(h_w_nodes_1);
            out.participation_rate_rel_change = (out.participation_rate_1 ...
                - out.participation_rate_0)/out.participation_rate_0;
            
            % share of individuals (in terms of population) who
            % participate and are in a certain sector
            out.part_M_0 = integrate(h_M_w_nodes_0);
            out.part_R_0 = integrate(h_R_w_nodes_0);
            out.part_C_0 = integrate(h_C_w_nodes_0);
            
            out.part_M_1 = integrate(h_M_w_nodes_1);
            out.part_R_1 = integrate(h_R_w_nodes_1);
            out.part_C_1 = integrate(h_C_w_nodes_1);
            
            out.part_M_rel_change = (out.part_M_1 - out.part_M_0)/out.part_M_0;
            out.part_R_rel_change = (out.part_R_1 - out.part_R_0)/out.part_R_0;
            out.part_C_rel_change = (out.part_C_1 - out.part_C_0)/out.part_C_0;
            
            out.part_M_change = out.part_M_1 - out.part_M_0;
            out.part_R_change = out.part_R_1 - out.part_R_0;
            out.part_C_change = out.part_C_1 - out.part_C_0;
            
            % employment shares among participants
            out.emp_M_0 = out.part_M_0/out.participation_rate_0;
            out.emp_R_0 = out.part_R_0/out.participation_rate_0;
            out.emp_C_0 = out.part_C_0/out.participation_rate_0;
            
            % average wages by sector
            out.average_wage_M_0 = 1/out.part_M_0 * integrate(w_nodes .* h_M_w_nodes_0);
            out.average_wage_R_0 = 1/out.part_R_0 * integrate(w_nodes .* h_R_w_nodes_0);
            out.average_wage_C_0 = 1/out.part_C_0 * integrate(w_nodes .* h_C_w_nodes_0);
            
            out.average_wage_M_1 = 1/out.part_M_1 * integrate(w_nodes .* h_M_w_nodes_1);
            out.average_wage_R_1 = 1/out.part_R_1 * integrate(w_nodes .* h_R_w_nodes_1);
            out.average_wage_C_1 = 1/out.part_C_1 * integrate(w_nodes .* h_C_w_nodes_1);
            
            % mass in wage bins in period 0, used to calibrate
            % skill distribution
            out.freqM = self.dist.h_M_mass_wagebins(Y_M_0, Y_R_0, Y_C_0, ...
                self.wagebins.M, out.participation_rate_0, inp);
            out.freqR = self.dist.h_R_mass_wagebins(Y_M_0, Y_R_0, Y_C_0, ...
                self.wagebins.R, out.participation_rate_0, inp);
            out.freqC = self.dist.h_C_mass_wagebins(Y_M_0, Y_R_0, Y_C_0, ...
                self.wagebins.C, out.participation_rate_0, inp);
            
            % participation elasticities wrt. robots
            out.eps_participation_K_B = out.participation_rate_rel_change/out.K_B_rel_change;
            out.eps_part_M_K_B = out.part_M_rel_change/out.K_B_rel_change;
            out.eps_part_R_K_B = out.part_R_rel_change/out.K_B_rel_change;
            out.eps_part_C_K_B = out.part_C_rel_change/out.K_B_rel_change;
            
            % wage percentiles among participants; the field names keep the
            % historical "deciles" naming even if other percentiles are set
            pct = self.percentiles;
            if isempty(pct)
                % e.g. objects created before percentiles was stored
                pct = 0.1:0.1:0.9;
            end
            pct_w_lb = 1e-4; % hard-coded bounds passed to compute_percentiles
            pct_w_ub = 1e4;
            
            handle_h_0 = @(w) 1/out.participation_rate_0 .* self.dist.h(w,inp.Y_M_0,inp.Y_R_0,inp.Y_C_0,inp);
            handle_h_1 = @(w) 1/out.participation_rate_1 .* self.dist.h(w,inp.Y_M_1,inp.Y_R_1,inp.Y_C_1,inp);
            
            out.deciles_0 = tools.Integration.compute_percentiles(handle_h_0, ...
                pct, pct_w_lb, pct_w_ub, self.num_nodes_percentiles, ...
                self.nodes_percentiles, self.weights_percentiles);
            out.deciles_1 = tools.Integration.compute_percentiles(handle_h_1, ...
                pct, pct_w_lb, pct_w_ub, self.num_nodes_percentiles, ...
                self.nodes_percentiles, self.weights_percentiles);
            out.eps_w_K = (out.deciles_1 - out.deciles_0)./out.deciles_0 ...
                ./ out.K_B_rel_change;
            
            % mass by occupation
            out.mass_M_0 = quad(@(w) self.dist.dens_skill.f_M(w,inp.Y_M_0,inp.Y_R_0,inp.Y_C_0,inp));
            out.mass_R_0 = quad(@(w) self.dist.dens_skill.f_R(w,inp.Y_M_0,inp.Y_R_0,inp.Y_C_0,inp));
            out.mass_C_0 = quad(@(w) self.dist.dens_skill.f_C(w,inp.Y_M_0,inp.Y_R_0,inp.Y_C_0,inp));
            
            % labor income (total, and average among participants)
            out.unconditional_labor_income = self.dist.total_labor_income(inp.Y_M_0,inp.Y_R_0, ...
                inp.Y_C_0, w_lb, w_ub, ...
                inp);
            out.avg_labor_income = 1/out.participation_rate_0 * ...
                out.unconditional_labor_income;
            
            % marginal tax rate at average income
            out.mtax_at_avg_income = ...
                self.dist.mtax_hsv_y(out.avg_labor_income * self.dist.scale_inc,inp.lambda_hsv,inp.tau_hsv);
            
            % average marginal income tax among participants
            epsilon = self.dist.util.epsilon;
            integrand = @(w) 1/out.participation_rate_0.* self.dist.mtax_hsv(w, epsilon, ...
                inp.lambda_hsv, ...
                inp.tau_hsv, ...
                inp.xi) ...
                .* self.dist.h(w,inp.Y_M_0,inp.Y_R_0,inp.Y_C_0,inp);
            out.avg_mtax = quad(integrand);
            
            % income-weighted avg. mtax
            integrand = @(w) 1/out.participation_rate_0.* ...
                self.dist.y_hsv(w, epsilon, ...
                inp.lambda_hsv, ...
                inp.tau_hsv, ...
                inp.xi).*...
                self.dist.mtax_hsv(w, epsilon, ...
                inp.lambda_hsv, ...
                inp.tau_hsv, ...
                inp.xi) ...
                .* self.dist.h(w,inp.Y_M_0,inp.Y_R_0,inp.Y_C_0,inp);
            out.avg_mtax_incweighted = quad(integrand)./out.avg_labor_income;
            
            % income tax revenue (unconditional on participation)
            integrand = @(w) self.dist.tax_hsv(w, epsilon, ...
                inp.lambda_hsv, ...
                inp.tau_hsv, ...
                inp.xi) ...
                .* self.dist.h(w,inp.Y_M_0,inp.Y_R_0,inp.Y_C_0,inp);
            out.income_tax_revenue = quad(integrand);
            
            % when computing the average tax rate, use average tax
            % revenue conditional on participation rather than
            % total tax revenue
            out.avg_tax_rate_0 = 1/out.participation_rate_0*out.income_tax_revenue/ ...
                out.avg_labor_income;
            
            out.transfer_lab_inc_share = inp.transfer ...
                / out.avg_labor_income;
            
        end
        
    end
    
end

